#!/bin/env python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
import matplotlib.cm as cm
from netCDF4 import Dataset

#----------
# Parameters
#----------
# Season tag substituted into the input file names below.
season = "MJJASO"

# (start year, end year) of each analysis period. Period 1 is stored with the
# 1977-2015 data sets, period 2 with the "_future" data set; period 3 is not
# referenced in the visible code.
syear1, eyear1 = 1977, 2015
syear2, eyear2 = 1977, 2050
syear3, eyear3 = 1977, 2100

# Category tags used in the input file names and as keys of the data dicts.
cat1, cat2, cat3 = "TCS", "C345", "HUR"

#----------
# Read Data
#----------
def _read_trends(path_template, season_tag, syear, eyear, cats):
    """Load trend fields for each category from per-category netCDF files.

    Parameters
    ----------
    path_template : str
        %-style path containing ``%(season)s`` and ``%(cat)s`` placeholders.
    season_tag : str
        Value substituted for ``%(season)s``.
    syear, eyear : int
        Analysis period. These are stored alongside the fields (the plot titles
        read them); they are not read from the file.
    cats : iterable of str
        Category tags; one file is opened per tag.

    Returns
    -------
    dict
        Maps each category tag to a dict with keys "slope" and "pval"
        (variables read with ``[:, :]`` and squeezed; the plotting code indexes
        them like the ``meshgrid(lon, lat)`` grid, i.e. (nlat, nlon)), "lon"
        and "lat" (coordinate variables), "syear" and "eyear".
    """
    out = {}
    for cat in cats:
        infl = path_template % {"season": season_tag, "cat": cat}
        # `with` guarantees the file is closed even if a variable read fails;
        # slicing with [:] copies the data into memory before the file closes.
        with Dataset(infl, "r") as f12:
            out[cat] = {
                "slope": f12.variables['slope'][:, :].squeeze(),
                "pval": f12.variables['pval'][:, :].squeeze(),
                "lon": f12.variables['lon'][:],
                "lat": f12.variables['lat'][:],
                "syear": syear,
                "eyear": eyear,
            }
    return out

# Obs
# 1977-2015
data_obs1 = _read_trends("./Data/trend_observations_%(season)s_%(cat)s.nc",
                         season, syear1, eyear1, (cat1, cat2))

# HiFLOR AllForc
# 1977-2015
data_hiflor_all1 = _read_trends("./Data/trend_HiFLOR_AllForc_%(season)s_%(cat)s.nc",
                                season, syear1, eyear1, (cat1, cat2))

# HiFLOR NatForc
# 1977-2015
data_hiflor_nat1 = _read_trends("./Data/trend_HiFLOR_NatForc_%(season)s_%(cat)s.nc",
                                season, syear1, eyear1, (cat1, cat2))

# HiFLOR AllForc (future file)
# 1977-2050
data_hiflor_all2 = _read_trends("./Data/trend_HiFLOR_AllForc_%(season)s_%(cat)s_future.nc",
                                season, syear2, eyear2, (cat1, cat2))

#----------
# Plot
#----------
fig = plt.figure(figsize=(19,6)) # set figure environment

ddomain2 = [100,180,0,50]  # map domain: [lon_min, lon_max, lat_min, lat_max]
cmap2=cm.coolwarm # set colormap
siglev=0.10  # grid points with pval <= siglev get a significance marker

#----
# Obs
# C345 (1977-2015)
#----
ax = fig.add_subplot(1,3,1) # set up panel plot
cat = "C345"
data = data_obs1[cat]
flon,flat = np.meshgrid(data['lon'],data['lat']) # (nlat, nlon) lon/lat grids
m = Basemap(projection='mill',llcrnrlat=ddomain2[2], urcrnrlat=ddomain2[3], llcrnrlon=ddomain2[0],urcrnrlon=ddomain2[1],
             lat_ts=10, resolution='l') # draw basemap
x,y=m(flon,flat) # lon,lat => map x,y

m.drawcoastlines() # draw coast line

contours2 = np.arange(-0.8,0.9,0.1)
# NOTE(review): *365*10 presumably rescales a per-day slope to per decade
# (matching the colorbar label) -- confirm the units stored in the input files.
cs = m.contourf(x,y,data['slope']*365*10, contours2, cmap=cmap2, extend="both") # plot contours

# Mark significant grid points: orange for non-negative trends, cyan for
# negative ones. Boolean masks cover the full (nlat, nlon) grid, so every
# longitude column is visited (a per-point loop must use the longitude
# length for the inner index). Masked (missing) pval/slope values are
# filled with False, i.e. treated as not significant.
sig = np.ma.filled(data['pval'] <= siglev, False)
pos = sig & np.ma.filled(data['slope'] >= 0, False)
neg = sig & np.ma.filled(data['slope'] < 0, False)
if pos.any():
    m.plot(x[pos], y[pos], 'o', ms=5, color='orange')
if neg.any():
    m.plot(x[neg], y[neg], 'o', ms=5, color='cyan')

#
plt.title("(c) Trends in MHF (Observations,  %4.4i\u2014%4.4i)" % (data['syear'],data['eyear']), fontsize=15)

meridians = np.arange(0.,360.,10.)
m.drawmeridians(meridians,labels=[0,0,0,1],linewidth=0, fontsize=10)
parallels = np.arange(0.,90, 10.)
m.drawparallels(parallels,labels=[1,0,0,1],linewidth=0, fontsize=10)

# add legend
legend_posi2 = [0.025, 0.10, 0.30, 0.030]
cax = fig.add_axes(legend_posi2)
art = plt.colorbar(cs, cax, orientation='horizontal')
art.set_label('Trend in MHF [number per decade]', fontsize=12)
art.ax.tick_params(labelsize=10)

#----
# HiFLOR AllForc
# C345 (1977-2015)
#----
ax = fig.add_subplot(1,3,2) # set up panel plot
cat = "C345"
data = data_hiflor_all1[cat]
flon,flat = np.meshgrid(data['lon'],data['lat']) # (nlat, nlon) lon/lat grids
m = Basemap(projection='mill',llcrnrlat=ddomain2[2], urcrnrlat=ddomain2[3], llcrnrlon=ddomain2[0],urcrnrlon=ddomain2[1],
             lat_ts=10, resolution='l') # draw basemap
x,y=m(flon,flat) # lon,lat => map x,y

m.drawcoastlines() # draw coast line

contours2 = np.arange(-0.4,0.45,0.05)
# NOTE(review): *365*10 presumably rescales a per-day slope to per decade
# (matching the colorbar label) -- confirm the units stored in the input files.
cs = m.contourf(x,y,data['slope']*365*10, contours2, cmap=cmap2, extend="both") # plot contours

# Mark significant grid points: orange for non-negative trends, cyan for
# negative ones. Boolean masks cover the full (nlat, nlon) grid, so every
# longitude column is visited (a per-point loop must use the longitude
# length for the inner index). Masked (missing) pval/slope values are
# filled with False, i.e. treated as not significant.
sig = np.ma.filled(data['pval'] <= siglev, False)
pos = sig & np.ma.filled(data['slope'] >= 0, False)
neg = sig & np.ma.filled(data['slope'] < 0, False)
if pos.any():
    m.plot(x[pos], y[pos], 'o', ms=5, color='orange')
if neg.any():
    m.plot(x[neg], y[neg], 'o', ms=5, color='cyan')

plt.title("(d) Trends in MHF (HiFLOR, AllForc, %4.4i\u2014%4.4i)" % (data['syear'],data['eyear']), fontsize=15)

meridians = np.arange(0.,360.,10.)
m.drawmeridians(meridians,labels=[0,0,0,1],linewidth=0, fontsize=10)
parallels = np.arange(0.,90, 10.)
m.drawparallels(parallels,labels=[1,0,0,1],linewidth=0, fontsize=10)

# add legend
legend_posi2 = [0.360, 0.10, 0.30, 0.030]
cax = fig.add_axes(legend_posi2)
art = plt.colorbar(cs, cax, orientation='horizontal')
art.set_label('Trend in MHF [number per decade]', fontsize=12)
art.ax.tick_params(labelsize=10)

fig.tight_layout()
plt.savefig("Fig9cd.png")
